It was hard to find a proper class-wise precision/recall logger on the web, so here is a simple one you can use:
Add the following to your experiment's `LightningModule`:
class Experiment(pl.LightningModule):
    """LightningModule that logs per-class precision and recall after each validation epoch.

    Only the metric-logging parts are shown here; the rest of the module is elided (``...``).
    """

    def __init__(self, model, loss, n_classes, dual_images=False) -> None:
        """Set up the experiment.

        Args:
            model: the network to train (its use is elided here — confirm).
            loss: the loss function (its use is elided here — confirm).
            n_classes: number of target classes. ``_calc_pr`` reads it as
                ``self.n_classes``, so the elided setup below must store it.
            dual_images: flag whose use is elided here — confirm its meaning.
        """
        super().__init__()
        ...
        ...

    def _calc_pr(self, outputs):
        """Compute per-class precision and recall over one validation epoch.

        Args:
            outputs: list of per-step dicts, each holding integer label tensors
                under ``"val_true"`` (targets) and ``"val_preds"`` (predictions).
                Labels are assumed to lie in ``[0, self.n_classes)`` — confirm
                against ``validation_step``.

        Returns:
            ``(precision, recall)``: two float tensors of length ``self.n_classes``,
            where for class ``c``::

                precision[c] = TP_c / (TP_c + FP_c)   # TP_c + FP_c = times c was predicted
                recall[c]    = TP_c / (TP_c + FN_c)   # TP_c + FN_c = times c was the target

            A class with a zero denominator gets 0.0 instead of NaN.
        """
        # flatten() tolerates (N, 1)-shaped label tensors; long() is required by bincount.
        y_true = torch.cat([x["val_true"] for x in outputs]).flatten().long()
        y_pred = torch.cat([x["val_preds"] for x in outputs]).flatten().long()

        # Vectorised counting stays on the tensors' device and avoids a
        # per-sample Python loop.
        true_pos = torch.bincount(y_true[y_true == y_pred], minlength=self.n_classes).float()
        pred_count = torch.bincount(y_pred, minlength=self.n_classes).float()  # TP + FP
        true_count = torch.bincount(y_true, minlength=self.n_classes).float()  # TP + FN

        # When a denominator is 0, true_pos is also 0, so clamping yields 0.0 rather than NaN.
        classwise_precision = true_pos / pred_count.clamp(min=1)
        classwise_recall = true_pos / true_count.clamp(min=1)
        return classwise_precision, classwise_recall

    def validation_epoch_end(self, outputs):
        """Log ``class_{i}_precision`` and ``class_{i}_recall`` for the finished epoch.

        Args:
            outputs: list of the dicts returned by ``validation_step`` during
                this epoch; each must contain ``"val_true"`` and ``"val_preds"``.

        NOTE(review): this hook was removed in PyTorch Lightning 2.0. On 2.x,
        accumulate the step outputs yourself and call ``_calc_pr`` from
        ``on_validation_epoch_end``.
        """
        ...
        # pr, rc
        classwise_precision, classwise_recall = self._calc_pr(outputs)
        metrics = {
            f"class_{i}_precision": val for i, val in enumerate(classwise_precision.tolist())
        }
        metrics.update(
            {f"class_{i}_recall": val for i, val in enumerate(classwise_recall.tolist())}
        )
        # NOTE(review): sync_dist=True reduces each value across processes
        # (mean by default). In multi-GPU runs that is the average of per-rank
        # ratios, not the ratio over the pooled predictions. Gather y_true and
        # y_pred first (e.g. self.all_gather) if exact global values matter.
        self.log_dict(metrics, sync_dist=True)
        ...